# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal, Optional

import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    Classification,
    invocation,
)
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
    BoundingBoxField,
    ColorField,
    FieldDescriptions,
    ImageField,
    InputField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker


@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1")
class ShowImageInvocation(BaseInvocation):
    """Displays a provided image using the OS image viewer, and passes it forward in the pipeline."""

    image: ImageField = InputField(description="The image to show")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Open the image with PIL's `show()` and return a reference to the same, unmodified image.

        Nothing is saved; the output points at the input image name.
        """
        image_name = self.image.image_name
        pil_image = context.images.get_pil(image_name)

        # TODO: how to handle failure?
        pil_image.show()

        width, height = pil_image.size
        return ImageOutput(image=ImageField(image_name=image_name), width=width, height=height)


@invocation(
    "blank_image",
    title="Blank Image",
    tags=["image"],
    category="image",
    version="1.2.2",
)
class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Creates a blank image and forwards it to the pipeline"""

    width: int = InputField(default=512, description="The width of the image")
    height: int = InputField(default=512, description="The height of the image")
    mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")
    color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Create a canvas of the requested size and mode filled with `color`, then save it."""
        fill = self.color.tuple()
        canvas = Image.new(self.mode, (self.width, self.height), fill)
        saved = context.images.save(image=canvas)
        return ImageOutput.build(saved)


@invocation(
    "img_crop",
    title="Crop Image",
    tags=["image", "crop"],
    category="image",
    version="1.2.2",
)
class ImageCropInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Crops an image to a specified box. The box can be outside of the image."""

    image: ImageField = InputField(description="The image to crop")
    x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
    y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
    width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
    height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Copy the requested rectangle of the source onto a transparent RGBA canvas.

        The source is pasted at the negated crop origin, so any part of the rectangle
        that lies outside the source stays fully transparent.
        """
        source = context.images.get_pil(self.image.image_name)

        transparent = (0, 0, 0, 0)
        canvas = Image.new(mode="RGBA", size=(self.width, self.height), color=transparent)
        offset = (-self.x, -self.y)
        canvas.paste(source, offset)

        return ImageOutput.build(context.images.save(image=canvas))


@invocation(
    invocation_type="img_pad_crop",
    title="Center Pad or Crop Image",
    category="image",
    tags=["image", "pad", "crop"],
    version="1.0.0",
)
class CenterPadCropInvocation(BaseInvocation):
    """Pad or crop an image's sides from the center by specified pixels. Positive values are outside of the image."""

    image: ImageField = InputField(description="The image to crop")
    left: int = InputField(
        default=0,
        description="Number of pixels to pad/crop from the left (negative values crop inwards, positive values pad outwards)",
    )
    right: int = InputField(
        default=0,
        description="Number of pixels to pad/crop from the right (negative values crop inwards, positive values pad outwards)",
    )
    top: int = InputField(
        default=0,
        description="Number of pixels to pad/crop from the top (negative values crop inwards, positive values pad outwards)",
    )
    bottom: int = InputField(
        default=0,
        description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Grow (positive offsets) or shrink (negative offsets) each side of the image.

        Added area is fully transparent. The source is placed `left` pixels right of and
        `top` pixels below the new canvas origin.

        Raises:
            ValueError: if the offsets would leave a zero or negative width or height.
        """
        image = context.images.get_pil(self.image.image_name)

        # Calculate the new canvas dimensions
        new_width = image.width + self.left + self.right
        new_height = image.height + self.top + self.bottom
        if new_width <= 0 or new_height <= 0:
            raise ValueError(
                f"Pad/crop offsets produce a non-positive image size ({new_width}x{new_height}) "
                f"from a {image.width}x{image.height} input"
            )

        image_crop = Image.new(mode="RGBA", size=(new_width, new_height), color=(0, 0, 0, 0))

        # Paste the input onto the new canvas; negative left/top push part of it off the edge (crop).
        image_crop.paste(image, (self.left, self.top))

        image_dto = context.images.save(image=image_crop)

        return ImageOutput.build(image_dto)


@invocation(
    "img_paste",
    title="Paste Image",
    tags=["image", "paste"],
    category="image",
    version="1.2.2",
)
class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Pastes an image into another image."""

    base_image: ImageField = InputField(description="The base image")
    image: ImageField = InputField(description="The image to paste")
    mask: Optional[ImageField] = InputField(
        default=None,
        description="The mask to use when pasting",
    )
    x: int = InputField(default=0, description="The left x coordinate at which to paste the image")
    y: int = InputField(default=0, description="The top y coordinate at which to paste the image")
    crop: bool = InputField(default=False, description="Crop to base image dimensions")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Alpha-composite `image` over `base_image` at (x, y).

        The canvas grows to hold both images, including negative offsets, unless `crop`
        is set. In that case the result is cut back to the base image's area. If a mask
        is supplied it is inverted before being used as the paste mask.
        """
        base = context.images.get_pil(self.base_image.image_name, mode="RGBA")
        overlay = context.images.get_pil(self.image.image_name, mode="RGBA")

        # TODO: probably shouldn't invert mask here... should user be required to do it?
        paste_mask = None
        if self.mask is not None:
            paste_mask = ImageOps.invert(context.images.get_pil(self.mask.image_name, mode="L"))

        # How far everything shifts right/down when the overlay starts at a negative coordinate.
        shift_x = -min(0, self.x)
        shift_y = -min(0, self.y)
        canvas_w = max(base.width, overlay.width + self.x) + shift_x
        canvas_h = max(base.height, overlay.height + self.y) + shift_y

        canvas = Image.new(mode="RGBA", size=(canvas_w, canvas_h), color=(0, 0, 0, 0))
        canvas.paste(base, (shift_x, shift_y))

        # Paste onto a separate transparent layer and alpha-composite it, so overlay transparency blends.
        layer = Image.new("RGBA", canvas.size)
        layer.paste(overlay, (self.x + shift_x, self.y + shift_y), mask=paste_mask)
        canvas = Image.alpha_composite(canvas, layer)

        if self.crop:
            canvas = canvas.crop((shift_x, shift_y, shift_x + base.width, shift_y + base.height))

        return ImageOutput.build(context.images.save(image=canvas))


@invocation(
    "tomask",
    title="Mask from Alpha",
    tags=["image", "mask"],
    category="image",
    version="1.2.2",
)
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Extracts the alpha channel of an image as a mask."""

    image: ImageField = InputField(description="The image to create the mask from")
    invert: bool = InputField(default=False, description="Whether or not to invert the mask")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Save the image's alpha band as a mask, optionally inverted.

        Images without an alpha band (e.g. RGB, L, P) are converted to RGBA first. They
        therefore yield the alpha PIL assigns in that conversion (fully opaque for RGB/L)
        rather than an unrelated color band.
        """
        image = context.images.get_pil(self.image.image_name)

        # `split()[-1]` would return the blue (RGB) or luminance (L) band for images without
        # alpha, so select the alpha band explicitly.
        if "A" not in image.getbands():
            image = image.convert("RGBA")
        image_mask = image.getchannel("A")

        if self.invert:
            image_mask = ImageOps.invert(image_mask)

        image_dto = context.images.save(image=image_mask, image_category=ImageCategory.MASK)

        return ImageOutput.build(image_dto)


@invocation(
    "img_mul",
    title="Multiply Images",
    tags=["image", "multiply"],
    category="image",
    version="1.2.2",
)
class ImageMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Multiplies two images together using `PIL.ImageChops.multiply()`."""

    image1: ImageField = InputField(description="The first image to multiply")
    image2: ImageField = InputField(description="The second image to multiply")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Load both inputs in their stored modes, multiply them pixel-wise, and save the product."""
        first, second = (context.images.get_pil(field.image_name) for field in (self.image1, self.image2))
        product = ImageChops.multiply(first, second)
        return ImageOutput.build(context.images.save(image=product))


# Band names selectable by the "Extract Image Channel" node (passed to `Image.getchannel`).
IMAGE_CHANNELS = Literal[
    "A",
    "R",
    "G",
    "B",
]


@invocation(
    "img_chan",
    title="Extract Image Channel",
    tags=["image", "channel"],
    category="image",
    version="1.2.2",
)
class ImageChannelInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Gets a channel from an image."""

    image: ImageField = InputField(description="The image to get the channel from")
    channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Save the requested band of the image as a single-channel image.

        If the image does not have that band (e.g. the default "A" on an RGB image), it
        is converted to RGBA first instead of letting `getchannel` raise ValueError.
        """
        image = context.images.get_pil(self.image.image_name)

        if self.channel not in image.getbands():
            image = image.convert("RGBA")

        channel_image = image.getchannel(self.channel)

        image_dto = context.images.save(image=channel_image)

        return ImageOutput.build(image_dto)


@invocation(
    "img_conv",
    title="Convert Image Mode",
    tags=["image", "convert"],
    category="image",
    version="1.2.2",
)
class ImageConvertInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Converts an image to a different mode."""

    image: ImageField = InputField(description="The image to convert")
    mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Convert the image to `self.mode` with PIL and save the result."""
        source = context.images.get_pil(self.image.image_name)
        target = source.convert(self.mode)
        return ImageOutput.build(context.images.save(image=target))


@invocation(
    "img_blur",
    title="Blur Image",
    tags=["image", "blur"],
    category="image",
    version="1.2.2",
)
class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Blurs an image"""

    image: ImageField = InputField(description="The image to blur")
    radius: float = InputField(default=8.0, ge=0, description="The blur radius")
    blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Blur the image in premultiplied-alpha space and return an RGBA result.

        RGB is multiplied by alpha before filtering and divided by the blurred alpha
        afterwards. This stops the color stored in transparent pixels from bleeding into
        visible ones.
        """
        image = context.images.get_pil(self.image.image_name, mode="RGBA")
        alpha = image.getchannel("A")

        # Premultiply: multiplying by alpha.convert("RGBA") == (a, a, a, 255) scales RGB by a/255.
        premultiplied_image = ImageChops.multiply(image, alpha.convert("RGBA"))
        premultiplied_image.putalpha(alpha)

        blur = (
            ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
        )
        blurred_image = premultiplied_image.filter(blur)
        blurred_alpha = blurred_image.getchannel("A")

        # Unpremultiply in float32 (float 32/64 division is much faster than float16).
        # The epsilon avoids division by zero where the blurred alpha is 0.
        rgb = numpy.asarray(blurred_image, dtype=numpy.float32)[:, :, :3]
        alpha_norm = numpy.asarray(blurred_alpha, dtype=numpy.float32) / 255.0
        rgb = rgb / (alpha_norm + 1e-6)[:, :, numpy.newaxis]

        rgb_image = Image.fromarray(numpy.clip(rgb, 0, 255).astype(numpy.uint8))
        result_image = Image.merge("RGBA", (*rgb_image.split(), blurred_alpha))

        image_dto = context.images.save(image=result_image)

        return ImageOutput.build(image_dto)


@invocation(
    "unsharp_mask",
    title="Unsharp Mask",
    tags=["image", "unsharp_mask"],
    category="image",
    version="1.2.2",
)
class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Applies an unsharp mask filter to an image"""

    image: ImageField = InputField(description="The image to use")
    radius: float = InputField(gt=0, description="Unsharp mask radius", default=2)
    strength: float = InputField(ge=0, description="Unsharp mask strength", default=50)

    def pil_from_array(self, arr):
        """Convert a float array in [0, 1] to an 8-bit PIL image (values are truncated, not rounded)."""
        return Image.fromarray((arr * 255).astype("uint8"))

    def array_from_pil(self, img):
        """Convert a PIL image to a float array divided by 255 ([0, 1] for 8-bit images)."""
        return numpy.array(img) / 255

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Sharpen the image by adding `strength` percent of (image - gaussian_blur(image)).

        The filter runs in RGB. The original mode is restored afterwards, and any alpha
        band (RGBA, LA, ...) is reattached unchanged.
        """
        image = context.images.get_pil(self.image.image_name)
        mode = image.mode

        # Preserve alpha for every mode that has one, not only RGBA; the RGB conversion drops it.
        alpha_channel = image.getchannel("A") if "A" in image.getbands() else None
        image = image.convert("RGB")
        image_blurred = self.array_from_pil(image.filter(ImageFilter.GaussianBlur(radius=self.radius)))

        image = self.array_from_pil(image)
        image += (image - image_blurred) * (self.strength / 100.0)
        image = numpy.clip(image, 0, 1)
        image = self.pil_from_array(image)

        image = image.convert(mode)

        # Reattach the source alpha channel, if there was one
        if alpha_channel is not None:
            image.putalpha(alpha_channel)

        image_dto = context.images.save(image=image)

        return ImageOutput.build(image_dto)


# User-facing resampling filter names; each is a key of PIL_RESAMPLING_MAP.
PIL_RESAMPLING_MODES = Literal["nearest", "box", "bilinear", "hamming", "bicubic", "lanczos"]


# Translates the user-facing names in PIL_RESAMPLING_MODES into PIL's resampling filters.
PIL_RESAMPLING_MAP: dict[str, Image.Resampling] = {
    "nearest": Image.Resampling.NEAREST,
    "box": Image.Resampling.BOX,
    "bilinear": Image.Resampling.BILINEAR,
    "hamming": Image.Resampling.HAMMING,
    "bicubic": Image.Resampling.BICUBIC,
    "lanczos": Image.Resampling.LANCZOS,
}


@invocation(
    "img_resize",
    title="Resize Image",
    tags=["image", "resize"],
    category="image",
    version="1.2.2",
)
class ImageResizeInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Resizes an image to specific dimensions"""

    image: ImageField = InputField(description="The image to resize")
    width: int = InputField(default=512, gt=0, description="The width to resize to (px)")
    height: int = InputField(default=512, gt=0, description="The height to resize to (px)")
    resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Resample the image to exactly `width` x `height` with the selected PIL filter."""
        source = context.images.get_pil(self.image.image_name)
        target_size = (self.width, self.height)
        resized = source.resize(target_size, resample=PIL_RESAMPLING_MAP[self.resample_mode])
        return ImageOutput.build(context.images.save(image=resized))


@invocation(
    "img_scale",
    title="Scale Image",
    tags=["image", "scale"],
    category="image",
    version="1.2.2",
)
class ImageScaleInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Scales an image by a factor"""

    image: ImageField = InputField(description="The image to scale")
    scale_factor: float = InputField(
        default=2.0,
        gt=0,
        description="The factor by which to scale the image",
    )
    resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Resize the image by `scale_factor` using the selected resampling filter.

        Target dimensions are truncated to integers and clamped to at least 1 px. Without
        the clamp, a small factor on a small image would ask PIL for a zero-sized resize,
        which it rejects.
        """
        image = context.images.get_pil(self.image.image_name)

        resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
        width = max(1, int(image.width * self.scale_factor))
        height = max(1, int(image.height * self.scale_factor))

        resize_image = image.resize(
            (width, height),
            resample=resample_mode,
        )

        image_dto = context.images.save(image=resize_image)

        return ImageOutput.build(image_dto)


@invocation(
    "img_lerp",
    title="Lerp Image",
    tags=["image", "lerp"],
    category="image",
    version="1.2.2",
)
class ImageLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Linear interpolation of all pixels of an image"""

    image: ImageField = InputField(description="The image to lerp")
    min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
    max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Remap every pixel value linearly from [0, 255] onto [min, max] (results truncated to uint8)."""
        source = context.images.get_pil(self.image.image_name)

        span = self.max - self.min
        normalized = numpy.asarray(source, dtype=numpy.float32) / 255
        remapped = normalized * span + self.min

        lerp_image = Image.fromarray(remapped.astype(numpy.uint8))
        return ImageOutput.build(context.images.save(image=lerp_image))


@invocation(
    "img_ilerp",
    title="Inverse Lerp Image",
    tags=["image", "ilerp"],
    category="image",
    version="1.2.2",
)
class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Inverse linear interpolation of all pixels of an image"""

    image: ImageField = InputField(description="The image to lerp")
    min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
    max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Map pixel values from [min, max] onto [0, 255], clamping values outside the range.

        When `min == max` the mapping degenerates to a threshold: values above `min`
        become 255 and all others 0, instead of dividing by zero. When `min > max` the
        mapping is inverted.
        """
        image = context.images.get_pil(self.image.image_name)

        image_arr = numpy.asarray(image, dtype=numpy.float32)
        span = float(self.max - self.min)
        if span == 0:
            normalized = (image_arr > self.min).astype(numpy.float32)
        else:
            # Clip to [0, 1] so the uint8 cast below never sees negative or >255 values.
            normalized = numpy.clip((image_arr - self.min) / span, 0, 1)

        ilerp_image = Image.fromarray((normalized * 255).astype(numpy.uint8))

        image_dto = context.images.save(image=ilerp_image)

        return ImageOutput.build(image_dto)


@invocation(
    "img_nsfw",
    title="Blur NSFW Image",
    tags=["image", "nsfw"],
    category="image",
    version="1.2.3",
)
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Add blur to NSFW-flagged images"""

    image: ImageField = InputField(description="The image to check")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Pass the image through `SafetyChecker.blur_if_nsfw` and save whatever it returns."""
        source = context.images.get_pil(self.image.image_name)
        context.logger.debug("Running NSFW checker")
        checked = SafetyChecker.blur_if_nsfw(source)
        return ImageOutput.build(context.images.save(image=checked))


@invocation(
    "img_watermark",
    title="Add Invisible Watermark",
    tags=["image", "watermark"],
    category="image",
    version="1.2.2",
)
class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Add an invisible watermark to an image"""

    # Description previously read "The image to check" (copied from the NSFW node).
    image: ImageField = InputField(description="The image to watermark")
    text: str = InputField(default="InvokeAI", description="Watermark text")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Embed `text` with `InvisibleWatermark.add_watermark` and save the result as a new image."""
        image = context.images.get_pil(self.image.image_name)
        new_image = InvisibleWatermark.add_watermark(image, self.text)
        image_dto = context.images.save(image=new_image)

        return ImageOutput.build(image_dto)


@invocation(
    "mask_edge",
    title="Mask Edge",
    tags=["image", "mask", "inpaint"],
    category="image",
    version="1.2.2",
)
class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Applies an edge mask to an image"""

    image: ImageField = InputField(description="The image to apply the mask to")
    edge_size: int = InputField(description="The size of the edge")
    edge_blur: int = InputField(description="The amount of blur on the edge")
    low_threshold: int = InputField(description="First threshold for the hysteresis procedure in Canny edge detection")
    high_threshold: int = InputField(
        description="Second threshold for the hysteresis procedure in Canny edge detection"
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Build an inverted mask of the input mask's soft (gray) pixels plus its Canny edges.

        The marked region is dilated with a 3x3 kernel for `int(edge_size / 2)` iterations,
        optionally box-blurred by `edge_blur`, then inverted and saved as a mask.
        """
        mask = context.images.get_pil(self.image.image_name).convert("L")

        npimg = numpy.asarray(mask, dtype=numpy.uint8)
        # 255 where the pixel is an intermediate gray (0 < v < 255), 0 where it is pure black or white.
        npgradient = numpy.uint8(255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0)))
        npedge = cv2.Canny(npimg, threshold1=self.low_threshold, threshold2=self.high_threshold)
        # Saturating add: a plain uint8 `+` would wrap 255 + 255 around to 254.
        npmask = cv2.add(npgradient, npedge)
        npmask = cv2.dilate(npmask, numpy.ones((3, 3), numpy.uint8), iterations=int(self.edge_size / 2))

        new_mask = Image.fromarray(npmask)

        if self.edge_blur > 0:
            new_mask = new_mask.filter(ImageFilter.BoxBlur(self.edge_blur))

        new_mask = ImageOps.invert(new_mask)

        image_dto = context.images.save(image=new_mask, image_category=ImageCategory.MASK)

        return ImageOutput.build(image_dto)


@invocation(
    "mask_combine",
    title="Combine Masks",
    tags=["image", "mask", "multiply"],
    category="image",
    version="1.2.2",
)
class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""

    mask1: ImageField = InputField(description="The first mask to combine")
    mask2: ImageField = InputField(description="The second image to combine")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Multiply both masks as grayscale, so the result is 0 wherever either input is 0."""
        grayscale = [context.images.get_pil(field.image_name).convert("L") for field in (self.mask1, self.mask2)]
        combined_mask = ImageChops.multiply(*grayscale)
        image_dto = context.images.save(image=combined_mask, image_category=ImageCategory.MASK)
        return ImageOutput.build(image_dto)


@invocation(
    "color_correct",
    title="Color Correct",
    tags=["image", "color"],
    category="image",
    version="2.0.0",
)
class ColorCorrectInvocation(BaseInvocation, WithMetadata, WithBoard):
    """
    Matches the color histogram of a base image to a reference image, optionally
    using a mask to only color-correct certain regions of the base image.
    """

    base_image: ImageField = InputField(description="The image to color-correct")
    color_reference: ImageField = InputField(description="Reference image for color-correction")
    mask: Optional[ImageField] = InputField(default=None, description="Optional mask to limit color correction area")
    colorspace: Literal["RGB", "YCbCr", "YCbCr-Chroma", "YCbCr-Luma"] = InputField(
        default="RGB", description="Colorspace in which to apply histogram matching", title="Color Space"
    )

    def _match_histogram_channel(self, source: numpy.ndarray, reference: numpy.ndarray) -> numpy.ndarray:
        """Remap `source` so its value distribution follows `reference` (CDF-based histogram matching).

        Callers pass uint8 channel arrays. The result has `source`'s shape and dtype uint8.
        """

        def normalized_cdf(values: numpy.ndarray) -> numpy.ndarray:
            counts, _ = numpy.histogram(values.flatten(), bins=256, range=(0, 256))
            cdf = counts.cumsum()
            # Leave an all-zero CDF untouched rather than dividing by zero.
            return cdf / cdf[-1] if cdf[-1] > 0 else cdf

        # For every source level, find the reference level with the same cumulative frequency.
        lookup_table = numpy.interp(normalized_cdf(source), normalized_cdf(reference), numpy.arange(256))
        return lookup_table[source].astype(numpy.uint8)

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Histogram-match the selected channels of the base image to the reference image.

        The base image's alpha is preserved. With a mask, white mask pixels keep the
        original base image, black pixels take the corrected result, and grays blend.

        Raises:
            ValueError: if a mask is given whose size differs from the base image.
        """
        base_image = context.images.get_pil(self.base_image.image_name, "RGBA")
        original_alpha = base_image.getchannel("A")

        # Work directly in RGB, or in YCbCr (all channels, chroma only, or luma only).
        working_mode = "RGB" if self.colorspace == "RGB" else "YCbCr"
        channels_to_match = {
            "RGB": [0, 1, 2],  # R, G, B
            "YCbCr": [0, 1, 2],  # Y, Cb, Cr
            "YCbCr-Chroma": [1, 2],  # Cb, Cr only
            "YCbCr-Luma": [0],  # Y only
        }[self.colorspace]

        base_array = numpy.asarray(base_image.convert(working_mode), dtype=numpy.uint8)
        ref_array = numpy.asarray(
            context.images.get_pil(self.color_reference.image_name, working_mode), dtype=numpy.uint8
        )

        corrected_array = base_array.copy()
        for idx in channels_to_match:
            corrected_array[:, :, idx] = self._match_histogram_channel(base_array[:, :, idx], ref_array[:, :, idx])

        corrected_image = Image.fromarray(corrected_array, mode=working_mode)
        if working_mode != "RGB":
            corrected_image = corrected_image.convert("RGB")

        result = corrected_image
        if self.mask is not None:
            mask_image = context.images.get_pil(self.mask.image_name, "L")
            if mask_image.size != corrected_image.size:
                raise ValueError("Mask size must match base image size.")
            # Paste the untouched base image back wherever the mask is white.
            result = corrected_image.copy()
            result.paste(base_image.convert("RGB"), mask=mask_image)

        result = result.convert("RGBA")
        result.putalpha(original_alpha)

        return ImageOutput.build(context.images.save(image=result))


@invocation(
    "img_hue_adjust",
    title="Adjust Image Hue",
    tags=["image", "hue"],
    category="image",
    version="1.2.2",
)
class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Adjusts the Hue of an image."""

    image: ImageField = InputField(description="The image to adjust")
    hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Rotate the hue of every pixel by `hue` degrees; return RGBA with the source alpha kept.

        Going through HSV drops the alpha band, so it is captured first and reattached
        afterwards. Previously, transparent images came back fully opaque.
        """
        # Load as RGBA (as the channel offset/multiply nodes do) so the HSV conversion
        # always starts from a known mode and the alpha band is available.
        pil_image = context.images.get_pil(self.image.image_name, "RGBA")
        alpha_channel = pil_image.getchannel("A")

        # Convert image to HSV color space
        hsv_image = numpy.array(pil_image.convert("HSV"))

        # Convert hue from degrees (any integer, wrapped into 0..359) to PIL's 0..255 scale
        hue = int(256 * ((self.hue % 360) / 360))

        # Rotate each hue and wrap at 256; widen first so the wrap is explicit, not uint8 overflow.
        hsv_image[:, :, 0] = (hsv_image[:, :, 0].astype(numpy.int32) + hue) % 256

        pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA")
        pil_image.putalpha(alpha_channel)

        image_dto = context.images.save(image=pil_image)

        return ImageOutput.build(image_dto)


# Channel choices for the channel offset/multiply nodes, labelled "<channel> (<color space>)".
# Every entry needs a matching key in CHANNEL_FORMATS, which gives its PIL mode and band index.
COLOR_CHANNELS = Literal[
    # RGBA
    "Red (RGBA)",
    "Green (RGBA)",
    "Blue (RGBA)",
    "Alpha (RGBA)",
    # CMYK
    "Cyan (CMYK)",
    "Magenta (CMYK)",
    "Yellow (CMYK)",
    "Black (CMYK)",
    # HSV
    "Hue (HSV)",
    "Saturation (HSV)",
    "Value (HSV)",
    # LAB
    "Luminosity (LAB)",
    "A (LAB)",
    "B (LAB)",
    # YCbCr
    "Y (YCbCr)",
    "Cb (YCbCr)",
    "Cr (YCbCr)",
]

# Maps each COLOR_CHANNELS label to (PIL mode to convert into, index of the band within that mode).
CHANNEL_FORMATS: dict[str, tuple[str, int]] = {
    # RGBA
    "Red (RGBA)": ("RGBA", 0),
    "Green (RGBA)": ("RGBA", 1),
    "Blue (RGBA)": ("RGBA", 2),
    "Alpha (RGBA)": ("RGBA", 3),
    # CMYK
    "Cyan (CMYK)": ("CMYK", 0),
    "Magenta (CMYK)": ("CMYK", 1),
    "Yellow (CMYK)": ("CMYK", 2),
    "Black (CMYK)": ("CMYK", 3),
    # HSV
    "Hue (HSV)": ("HSV", 0),
    "Saturation (HSV)": ("HSV", 1),
    "Value (HSV)": ("HSV", 2),
    # LAB
    "Luminosity (LAB)": ("LAB", 0),
    "A (LAB)": ("LAB", 1),
    "B (LAB)": ("LAB", 2),
    # YCbCr
    "Y (YCbCr)": ("YCbCr", 0),
    "Cb (YCbCr)": ("YCbCr", 1),
    "Cr (YCbCr)": ("YCbCr", 2),
}


@invocation(
    "img_channel_offset",
    title="Offset Image Channel",
    tags=[
        "image",
        "offset",
        "red",
        "green",
        "blue",
        "alpha",
        "cyan",
        "magenta",
        "yellow",
        "black",
        "hue",
        "saturation",
        "luminosity",
        "value",
    ],
    category="image",
    version="1.2.3",
)
class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Add or subtract a value from a specific color channel of an image."""

    image: ImageField = InputField(description="The image to adjust")
    channel: COLOR_CHANNELS = InputField(description="Which channel to adjust")
    offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Shift one channel of the chosen color space by `offset`, then convert back to RGBA.

        Hue wraps modulo 256; every other channel is clamped to 0..255. The source alpha
        is restored afterwards unless the alpha channel itself was the one adjusted.
        """
        rgba = context.images.get_pil(self.image.image_name, "RGBA")
        mode, channel_number = CHANNEL_FORMATS[self.channel]

        # Widen to int so the offset can go below 0 / above 255 before wrapping or clamping.
        pixels = numpy.array(rgba.convert(mode)).astype(int)
        shifted = pixels[:, :, channel_number] + self.offset
        pixels[:, :, channel_number] = shifted % 256 if self.channel == "Hue (HSV)" else numpy.clip(shifted, 0, 255)

        adjusted = Image.fromarray(pixels.astype(numpy.uint8), mode=mode).convert("RGBA")
        if self.channel != "Alpha (RGBA)":
            adjusted.putalpha(rgba.getchannel("A"))

        return ImageOutput.build(context.images.save(image=adjusted))


@invocation(
    "img_channel_multiply",
    title="Multiply Image Channel",
    tags=[
        "image",
        "invert",
        "scale",
        "multiply",
        "red",
        "green",
        "blue",
        "alpha",
        "cyan",
        "magenta",
        "yellow",
        "black",
        "hue",
        "saturation",
        "luminosity",
        "value",
    ],
    category="image",
    version="1.2.3",
)
class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Scale a specific color channel of an image."""

    image: ImageField = InputField(description="The image to adjust")
    channel: COLOR_CHANNELS = InputField(description="Which channel to adjust")
    scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.")
    invert_channel: bool = InputField(default=False, description="Invert the channel after scaling")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Scale one channel of the chosen color space (clamped to 0..255), optionally invert it.

        The result is converted back to RGBA. The source alpha is restored unless the
        alpha channel itself was the one adjusted.
        """
        rgba = context.images.get_pil(self.image.image_name, "RGBA")
        mode, channel_number = CHANNEL_FORMATS[self.channel]

        pixels = numpy.array(rgba.convert(mode)).astype(float)
        scaled = numpy.clip(pixels[:, :, channel_number] * self.scale, 0, 255)
        if self.invert_channel:
            scaled = 255 - scaled
        pixels[:, :, channel_number] = scaled

        adjusted = Image.fromarray(pixels.astype(numpy.uint8), mode=mode).convert("RGBA")
        if self.channel != "Alpha (RGBA)":
            adjusted.putalpha(rgba.getchannel("A"))

        return ImageOutput.build(context.images.save(image=adjusted))


@invocation(
    "save_image",
    title="Save Image",
    tags=["primitives", "image"],
    category="primitives",
    version="1.2.2",
    use_cache=False,
)
class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""

    # The image whose pixels are stored again as a new image.
    image: ImageField = InputField(description=FieldDescriptions.image)

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Load the referenced image and save it again, returning the new image's output."""
        pil_image = context.images.get_pil(self.image.image_name)
        return ImageOutput.build(context.images.save(image=pil_image))


@invocation(
    "canvas_paste_back",
    title="Canvas Paste Back",
    tags=["image", "combine"],
    category="image",
    version="1.0.1",
)
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Combines two images by using the mask provided. Intended for use on the Unified Canvas."""

    source_image: ImageField = InputField(description="The source image")
    target_image: ImageField = InputField(description="The target image")
    mask: ImageField = InputField(
        description="The mask to use when pasting",
    )
    mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")

    def _prepare_mask(self, mask: Image.Image) -> Image.Image:
        """Build the paste mask: optionally erode + blur, then convert to "L" and invert.

        When ``mask_blur`` is positive, the mask is eroded three times with a
        ``mask_blur`` x ``mask_blur`` kernel and then Gaussian-blurred with radius
        ``mask_blur``. When ``mask_blur`` is 0, the mask is only converted to "L"
        and inverted.

        Args:
            mask: The raw mask image as loaded from storage.

        Returns:
            An inverted single-channel ("L") mask suitable for ``Image.paste``.
        """
        if self.mask_blur > 0:
            # The erosion used to run unconditionally and its result was discarded when no
            # blur was requested; only do (and risk) the cv2 work when its output is used.
            kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
            eroded_mask_array = cv2.erode(numpy.array(mask), kernel, iterations=3)
            mask = Image.fromarray(eroded_mask_array).filter(ImageFilter.GaussianBlur(self.mask_blur))
        return ImageOps.invert(mask.convert("L"))

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Paste the target image onto the source image through the prepared mask and save it."""
        source_image = context.images.get_pil(self.source_image.image_name)
        target_image = context.images.get_pil(self.target_image.image_name)
        mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))

        # Target pixels replace source pixels where the inverted mask is bright,
        # i.e. where the input mask was dark.
        source_image.paste(target_image, (0, 0), mask)

        image_dto = context.images.save(image=source_image)
        return ImageOutput.build(image_dto)


@invocation(
    "mask_from_id",
    title="Mask from Segmented Image",
    tags=["image", "mask", "id"],
    category="image",
    version="1.0.1",
)
class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generate a mask for a particular color in an ID Map"""

    # The segmentation / ID map, loaded as RGBA.
    image: ImageField = InputField(description="The image to create the mask from")
    # The ID color to select.
    color: ColorField = InputField(description="ID color to mask")
    # Pixels whose Euclidean color distance is strictly below this value are selected.
    threshold: int = InputField(default=100, description="Threshold for color detection")
    # When set, the selected and unselected regions are swapped.
    invert: bool = InputField(default=False, description="Whether or not to invert the mask")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Return a white-on-black mask of the pixels near ``self.color`` (black-on-white if inverted)."""
        pixels = numpy.asarray(context.images.get_pil(self.image.image_name, mode="RGBA"))
        # NOTE(review): assumes color.tuple() has four components to broadcast against RGBA pixels.
        target_color = numpy.array(self.color.tuple())

        # Subtracting promotes to a signed integer type, so there is no uint8 wrap-around.
        distance = numpy.linalg.norm(pixels - target_color, axis=-1)

        selected = distance < self.threshold
        if self.invert:
            selected = ~selected

        mask_image = Image.fromarray(numpy.where(selected, 255, 0).astype(numpy.uint8))

        image_dto = context.images.save(image=mask_image, image_category=ImageCategory.MASK)
        return ImageOutput.build(image_dto)


@invocation(
    "canvas_v2_mask_and_crop",
    title="Canvas V2 Mask and Crop",
    tags=["image", "mask", "id"],
    category="image",
    version="1.0.0",
    classification=Classification.Deprecated,
)
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Handles Canvas V2 image output masking and cropping"""

    source_image: ImageField | None = InputField(
        default=None,
        description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.",
    )
    generated_image: ImageField = InputField(description="The image to apply the mask to")
    mask: ImageField = InputField(description="The mask to apply")
    mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")

    def _prepare_mask(self, mask: Image.Image) -> Image.Image:
        """Build the paste/alpha mask: optionally erode + blur, then convert to "L" and invert.

        When ``mask_blur`` is positive, the mask is eroded three times with a
        ``mask_blur`` x ``mask_blur`` kernel and then Gaussian-blurred with radius
        ``mask_blur``. When ``mask_blur`` is 0, the mask is only converted to "L"
        and inverted.

        Args:
            mask: The raw mask image as loaded from storage.

        Returns:
            An inverted single-channel ("L") mask.
        """
        if self.mask_blur > 0:
            # Only run the cv2 erosion when its result is actually used; it was previously
            # computed and discarded when no blur was requested.
            kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
            eroded_mask_array = cv2.erode(numpy.array(mask), kernel, iterations=3)
            mask = Image.fromarray(eroded_mask_array).filter(ImageFilter.GaussianBlur(self.mask_blur))
        return ImageOps.invert(mask.convert("L"))

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Composite the generated image onto the source image, or return it with the mask as alpha.

        With a source image, the generated image is pasted onto it through the prepared mask.
        Without one, the prepared mask becomes the generated image's alpha channel.
        """
        mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
        # Both branches need the generated image, so load it once.
        generated_image = context.images.get_pil(self.generated_image.image_name)

        if self.source_image is not None:
            source_image = context.images.get_pil(self.source_image.image_name)
            source_image.paste(generated_image, (0, 0), mask)
            image_dto = context.images.save(image=source_image)
        else:
            generated_image.putalpha(mask)
            image_dto = context.images.save(image=generated_image)

        return ImageOutput.build(image_dto)


@invocation(
    "expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1"
)
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
    The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
    The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
    If the fade size is 0, the mask is returned as-is.
    """

    # Input mask, loaded as single-channel "L".
    mask: ImageField = InputField(description="The mask to expand")
    # Pixels strictly above this value count as white; all others count as black.
    threshold: int = InputField(default=0, ge=0, le=255, description="The threshold for the binary mask (0-255)")
    # Width of the black-to-white ramp, in pixels; 0 disables the fade.
    fade_size_px: int = InputField(default=32, ge=0, description="The size of the fade in pixels")

    @staticmethod
    def _fade_curve(d_norm: numpy.ndarray, fade_size_px: int) -> numpy.ndarray:
        """Map normalized distances from the mask edge (0..1) to fade values (0=black .. 1=white).

        A cubic is least-squares fitted (``numpy.polyfit``) through hand-picked control points:
          * (0, 0): the mask edge itself stays black.
          * (1 / fade_size_px, 0): one pixel out is still black. This effectively grows the
            mask by 1px, which avoids artifacts at the mask edge.
          * (0.2, 0.2) and (0.8, 0.9): these shape the ramp.
          * (1, 1): fully white at the end of the fade.

        The fit is approximate, so values at d_norm >= 1 could come out slightly below 1.0. That
        would leave the area past the fade not quite white, so those values are forced to exactly
        1.0. Everything is then clipped to [0, 1].

        NOTE(review): for fade_size_px < 5 the "one pixel" control point lies at or beyond x=0.2,
        and at fade_size_px == 1 it coincides with x=1.0. The control points then conflict.
        """
        x_points = numpy.array([0.0, 1.0 / fade_size_px, 0.2, 0.8, 1.0])
        y_points = numpy.array([0.0, 0.0, 0.2, 0.9, 1.0])
        curve = numpy.poly1d(numpy.polyfit(x_points, y_points, 3))
        return numpy.clip(numpy.where(d_norm >= 1.0, 1.0, curve(d_norm)), 0, 1)

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Binarize the mask, then fade its white region in from the black edge over ``fade_size_px``."""
        pil_mask = context.images.get_pil(self.mask.image_name, mode="L")

        if self.fade_size_px == 0:
            # No fade requested: store the mask unchanged.
            return ImageOutput.build(context.images.save(image=pil_mask, image_category=ImageCategory.MASK))

        # Binarize first; fading a soft mask directly produces artifacts.
        is_black = numpy.array(pil_mask) <= self.threshold

        # Distance transform (L2, 5x5) over the white region, measuring distance out from the black area.
        white_region = (~is_black).astype(numpy.uint8)
        distance = cv2.distanceTransform(white_region, cv2.DIST_L2, 5)

        # 0 at the mask edge, reaching 1 at fade_size_px pixels away (and clamped there).
        d_norm = numpy.clip(distance / self.fade_size_px, 0, 1)
        fade = self._fade_curve(d_norm, self.fade_size_px)

        # Original black pixels stay black; everything else takes the faded gray level.
        np_result = numpy.where(is_black, 0, (fade * 255).astype(numpy.uint8))
        pil_result = Image.fromarray(np_result.astype(numpy.uint8), mode="L")

        return ImageOutput.build(context.images.save(image=pil_result, image_category=ImageCategory.MASK))


@invocation(
    "apply_mask_to_image",
    title="Apply Mask to Image",
    tags=["image", "mask", "blend"],
    category="image",
    version="1.0.0",
)
class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """
    Extracts a region from an image by using the mask as the image's alpha channel.
    White areas of the mask are kept (opaque) and black areas are discarded (transparent); enable invert_mask to reverse this.
    """

    image: ImageField = InputField(description="The image from which to extract the masked region")
    mask: ImageField = InputField(description="The mask defining the region (white=keep, black=discard)")
    invert_mask: bool = InputField(
        default=False,
        description="Whether to invert the mask before applying it",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Replace the image's alpha channel with the (optionally inverted) mask and save the result.

        Raises:
            ValueError: if the mask and image dimensions differ.
        """
        image = context.images.get_pil(self.image.image_name, mode="RGBA")
        mask = context.images.get_pil(self.mask.image_name, mode="L")

        # Image.merge would reject mismatched bands with a bare "size mismatch"; fail with context instead.
        if mask.size != image.size:
            raise ValueError(f"Mask size {mask.size} does not match image size {image.size}")

        if self.invert_mask:
            # ImageOps.invert returns a new image, so the loaded mask is never modified.
            mask = ImageOps.invert(mask)

        # The mask becomes the alpha channel: white -> opaque (kept), black -> transparent (discarded),
        # intermediate grays -> partially transparent.
        r, g, b, _ = image.split()
        result_image = Image.merge("RGBA", (r, g, b, mask))

        image_dto = context.images.save(image=result_image)
        return ImageOutput.build(image_dto)


@invocation(
    "img_noise",
    title="Add Image Noise",
    tags=["image", "noise"],
    category="image",
    version="1.1.0",
)
class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Add noise to an image"""

    image: ImageField = InputField(description="The image to add noise to")
    mask: Optional[ImageField] = InputField(
        default=None, description="Optional mask determining where to apply noise (black=noise, white=no noise)"
    )
    seed: int = InputField(
        default=0,
        ge=0,
        le=SEED_MAX,
        description=FieldDescriptions.seed,
    )
    noise_type: Literal["gaussian", "salt_and_pepper"] = InputField(
        default="gaussian",
        description="The type of noise to add",
    )
    amount: float = InputField(default=0.1, ge=0, le=1, description="The amount of noise to add")
    noise_color: bool = InputField(default=True, description="Whether to add colored noise")
    size: int = InputField(default=1, ge=1, description="The size of the noise points")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Blend seeded noise into the image, optionally only where the mask is dark.

        The noise is blended with weight ``amount``. For salt-and-pepper noise, ``amount`` is
        also the probability of a white sample. The original alpha channel is preserved.

        Raises:
            ValueError: if ``noise_type`` is not a supported value.
        """
        image = context.images.get_pil(self.image.image_name, mode="RGBA")

        # Save the alpha channel so it can be restored after the RGB blend.
        alpha = image.getchannel("A")

        # Seeded generator: the same seed always produces the same noise pattern.
        rs = numpy.random.RandomState(numpy.random.MT19937(numpy.random.SeedSequence(self.seed)))

        # Noise is sampled on a grid `size` times coarser than the image and upscaled with NEAREST,
        # so each noise point covers a size x size block. Clamp to at least one cell so that a
        # `size` larger than the image does not produce an empty array.
        grid_h = max(1, image.height // self.size)
        grid_w = max(1, image.width // self.size)
        shape = (grid_h, grid_w, 3) if self.noise_color else (grid_h, grid_w)

        if self.noise_type == "gaussian":
            noise = rs.normal(0, 1, shape) * 255
        elif self.noise_type == "salt_and_pepper":
            noise = rs.choice([0, 255], shape, p=[1 - self.amount, self.amount])
        else:
            raise ValueError(f"Unsupported noise type: {self.noise_type}")

        if not self.noise_color:
            # Monochrome noise: replicate the single plane into R, G and B.
            noise = numpy.stack([noise] * 3, axis=-1)

        # Gaussian samples fall far outside 0..255, and casting out-of-range floats to uint8 is
        # undefined (results differ by platform), so clamp before casting. Salt-and-pepper values
        # are already in range and are unaffected.
        noise_image = Image.fromarray(numpy.clip(noise, 0, 255).astype(numpy.uint8), mode="RGB").resize(
            (image.width, image.height), Image.Resampling.NEAREST
        )

        # Create a noisy version of the input image.
        noisy_image = Image.blend(image.convert("RGB"), noise_image, self.amount).convert("RGBA")

        if self.mask is not None:
            mask_image = context.images.get_pil(self.mask.image_name, mode="L")

            if mask_image.size != image.size:
                mask_image = mask_image.resize(image.size, Image.Resampling.LANCZOS)

            # The mask convention is black=noise, so invert it to get a paste mask where noise applies.
            result_image = image.copy()
            result_image.paste(noisy_image, (0, 0), mask=ImageOps.invert(mask_image))
        else:
            result_image = noisy_image

        # Restore the original alpha channel.
        result_image.putalpha(alpha)

        image_dto = context.images.save(image=result_image)
        return ImageOutput.build(image_dto)


@invocation(
    "crop_image_to_bounding_box",
    title="Crop Image to Bounding Box",
    category="image",
    version="1.0.0",
    tags=["image", "crop"],
)
class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Crop an image to the given bounding box. If the bounding box is omitted, the image is cropped to the non-transparent pixels."""

    # The image to crop; loaded in its stored mode.
    image: ImageField = InputField(description="The image to crop")
    # Optional explicit crop box; when None, PIL's getbbox() of the image is used.
    bounding_box: BoundingBoxField | None = InputField(
        default=None, description="The bounding box to crop the image to"
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Crop to the explicit bounding box, or to the image's own content bounds, and save."""
        image = context.images.get_pil(self.image.image_name)

        if self.bounding_box is None:
            # NOTE(review): getbbox() trims fully transparent pixels only when the image has an alpha
            # band; for images without alpha it trims all-zero (black) pixels instead.
            crop_box = image.getbbox()
        else:
            crop_box = self.bounding_box.tuple()

        return ImageOutput.build(context.images.save(image=image.crop(crop_box)))


@invocation(
    "paste_image_into_bounding_box",
    title="Paste Image into Bounding Box",
    category="image",
    version="1.0.0",
    tags=["image", "crop"],
)
class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Paste the source image into the target image at the given bounding box.

    The source image must be the same size as the bounding box, and the bounding box must fit within the target image."""

    # Image to paste; loaded as RGBA so its alpha can serve as the paste mask.
    source_image: ImageField = InputField(description="The image to paste")
    # Image that receives the paste; loaded as RGBA.
    target_image: ImageField = InputField(description="The image to paste into")
    # Destination region inside the target image.
    bounding_box: BoundingBoxField = InputField(description="The bounding box to paste the image into")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Paste the source into the target at the bounding box, using the source as its own mask."""
        pasted = context.images.get_pil(self.source_image.image_name, mode="RGBA")
        canvas = context.images.get_pil(self.target_image.image_name, mode="RGBA")

        # Passing the RGBA source as the mask lets its alpha band control the blend.
        canvas.paste(pasted, self.bounding_box.tuple(), pasted)

        return ImageOutput.build(context.images.save(image=canvas))


@invocation(
    "flux_kontext_image_prep",
    title="FLUX Kontext Image Prep",
    tags=["image", "concatenate", "flux", "kontext"],
    category="image",
    version="1.0.0",
)
class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Prepares an image or images for use with FLUX Kontext. The first/single image is resized to the nearest
    preferred Kontext resolution. All other images are concatenated horizontally, maintaining their aspect ratio."""

    # One to ten images; the first sets the shared row height.
    images: list[ImageField] = InputField(
        description="The images to concatenate",
        min_length=1,
        max_length=10,
    )

    # When False, the first image keeps its own size, snapped down to multiples of 16.
    use_preferred_resolution: bool = InputField(
        default=True, description="Use FLUX preferred resolutions for the first image"
    )

    @staticmethod
    def _floor_to_16(value: int) -> int:
        """Round ``value`` down to a multiple of 16, but never below 16."""
        return max(16, 16 * (value // 16))

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Resize every image to a common height and lay them out left-to-right on one RGB canvas."""
        from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS

        loaded = [context.images.get_pil(field.image_name, mode="RGBA") for field in self.images]

        first = loaded[0]
        if self.use_preferred_resolution:
            ratio = first.width / first.height
            # Preferred resolution whose aspect ratio is closest to the first image's (first match wins ties).
            target_width, target_height = min(
                PREFERED_KONTEXT_RESOLUTIONS, key=lambda size: abs(ratio - size[0] / size[1])
            )
            # BFL's scaling formula, 8 * (2 * int(x / 16)), i.e. floor to a multiple of 16.
            row_height = 16 * int(target_height / 16)
            first_width = 16 * int(target_width / 16)
        else:
            row_height = self._floor_to_16(first.height)
            first_width = self._floor_to_16(first.width)

        # Later images keep their aspect ratio at the shared height; widths are multiples of 16
        # for proper VAE encoding.
        widths = [first_width]
        for img in loaded[1:]:
            widths.append(self._floor_to_16(int(row_height * (img.width / img.height))))

        resized = [img.resize((w, row_height), Image.Resampling.LANCZOS) for img, w in zip(loaded, widths)]

        canvas = Image.new("RGB", (sum(widths), row_height))
        x = 0
        for img in resized:
            canvas.paste(img, (x, 0))
            x += img.width

        return ImageOutput.build(context.images.save(image=canvas))
